{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Boost\n",
    "\n",
    "“装袋”（bagging）和“提升”（boost）是构建组合模型的两种最主要的方法，所谓的组合模型是由多个基本模型构成的模型，组合模型的预测效果往往比任意一个基本模型的效果都要好。\n",
    "\n",
    "- 装袋：每个基本模型由从总体样本中随机抽样得到的不同数据集进行训练得到，通过重抽样得到不同训练数据集的过程称为装袋。\n",
    "\n",
    "- 提升：每个基本模型训练时的数据集采用不同权重，针对上一个基本模型分类错误的样本增加权重，使得新的模型重点关注误分类样本\n",
    "\n",
    "### AdaBoost\n",
    "\n",
    "AdaBoost是AdaptiveBoost的缩写，表明该算法是具有适应性的提升算法。\n",
    "\n",
    "算法的步骤如下：\n",
    "\n",
    "1）给每个训练样本（x1,x2,….,xN）分配权重，初始权重$w_{1}$均为1/N。\n",
    "\n",
    "2）针对带有权值的样本进行训练，得到模型$G_m$（初始模型为G1）。\n",
    "\n",
    "3）计算模型$G_m$的误分率$e_m=\\sum_{i=1}^Nw_iI(y_i\\not= G_m(x_i))$\n",
    "\n",
    "4）计算模型$G_m$的系数$\\alpha_m=0.5\\log[(1-e_m)/e_m]$\n",
    "\n",
    "5）根据误分率e和当前权重向量$w_m$更新权重向量$w_{m+1}$。\n",
    "\n",
    "6）计算组合模型$f(x)=\\sum_{m=1}^M\\alpha_mG_m(x_i)$的误分率。\n",
    "\n",
    "7）当组合模型的误分率或迭代次数低于一定阈值，停止迭代；否则，回到步骤2）"
   ]
  },
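  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 5 of the list above leaves the update rule implicit. A sketch of the usual form, which is also what the `_Z` and `_w` methods later in this notebook compute, is given here. The notation is an assumption of this sketch: $w_{m,i}$ is the weight of sample $i$ in round $m$, and $Z_m$ is a normalization factor that keeps the weights summing to 1.\n",
    "\n",
    "$$w_{m+1,i}=\\frac{w_{m,i}}{Z_m}\\exp\\bigl(-\\alpha_m y_i G_m(x_i)\\bigr),\\qquad Z_m=\\sum_{i=1}^N w_{m,i}\\exp\\bigl(-\\alpha_m y_i G_m(x_i)\\bigr)$$\n",
    "\n",
    "With labels in $\\{-1,+1\\}$, the product $y_i G_m(x_i)$ is $+1$ for a correctly classified sample and $-1$ for a misclassified one, so correct samples are scaled by $e^{-\\alpha_m}$ and misclassified ones by $e^{\\alpha_m}$. Substituting $\\alpha_m=0.5\\log[(1-e_m)/e_m]$ and using $\\sum_i w_{m,i}=1$ gives\n",
    "\n",
    "$$Z_m=2\\sqrt{e_m(1-e_m)}$$\n",
    "\n",
    "so that $w_{m+1,i}=w_{m,i}/[2(1-e_m)]$ for correctly classified samples and $w_{m+1,i}=w_{m,i}/(2e_m)$ for misclassified ones. After the update the misclassified samples therefore carry half of the total weight whenever $0<e_m<1$."
   ]
  },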
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection  import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# data\n",
    "def create_data():\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    for i in range(len(data)):\n",
    "        if data[i,-1] == 0:\n",
    "            data[i,-1] = -1\n",
    "    # print(data)\n",
    "    return data[:,:2], data[:,-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x204795c3940>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9Jdfou5SXLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uX
npvHHn3EH5cglfT08PAwMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSFdPeDwC+AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4
+GmnuLCCG7LpFT3j4quqDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O8BV7v7lCF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvm
vPLlMe5jaTqdc5f6xJgNzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZnZjmXXMzO42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8
NtD4GkmnZUREEqTmLiKSIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x204791b1a90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "\n",
    "### AdaBoost in Python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class AdaBoost:\n",
    "    def __init__(self, n_estimators=50, learning_rate=1.0):\n",
    "        self.clf_num = n_estimators\n",
    "        self.learning_rate = learning_rate\n",
    "    \n",
    "    def init_args(self, datasets, labels):\n",
    "        \n",
    "        self.X = datasets\n",
    "        self.Y = labels\n",
    "        self.M, self.N = datasets.shape\n",
    "        \n",
    "        # 弱分类器数目和集合\n",
    "        self.clf_sets = []\n",
    "        \n",
    "        # 初始化weights\n",
    "        self.weights = [1.0/self.M]*self.M\n",
    "        \n",
    "        # G(x)系数 alpha\n",
    "        self.alpha = []\n",
    "        \n",
    "    def _G(self, features, labels, weights):\n",
    "        m = len(features)\n",
    "        error = 100000.0 # 无穷大\n",
    "        best_v = 0.0\n",
    "        # 单维features\n",
    "        features_min = min(features)\n",
    "        features_max = max(features)\n",
    "        n_step = (features_max - features_min + self.learning_rate) // self.learning_rate\n",
    "        # print('n_step:{}'.format(n_step))\n",
    "        direct, compare_array = None, None\n",
    "        for i in range(1, int(n_step)):\n",
    "            v = features_min + self.learning_rate * i\n",
    "            \n",
    "            if v not in features:\n",
    "                # 误分类计算\n",
    "                compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)])\n",
    "                weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]])\n",
    "                \n",
    "                compare_array_nagetive = np.array([-1 if features[k] > v else 1 for k in range(m)])\n",
    "                weight_error_nagetive = sum([weights[k] for k in range(m) if compare_array_nagetive[k] != labels[k]])\n",
    "\n",
    "                if weight_error_positive < weight_error_nagetive:\n",
    "                    weight_error = weight_error_positive\n",
    "                    _compare_array = compare_array_positive\n",
    "                    direct = 'positive'\n",
    "                else:\n",
    "                    weight_error = weight_error_nagetive\n",
    "                    _compare_array = compare_array_nagetive\n",
    "                    direct = 'nagetive'\n",
    "                    \n",
    "                # print('v:{} error:{}'.format(v, weight_error))\n",
    "                if weight_error < error:\n",
    "                    error = weight_error\n",
    "                    compare_array = _compare_array\n",
    "                    best_v = v\n",
    "        return best_v, direct, error, compare_array\n",
    "        \n",
    "    # 计算alpha\n",
    "    def _alpha(self, error):\n",
    "        return 0.5 * np.log((1-error)/error)\n",
    "    \n",
    "    # 规范化因子\n",
    "    def _Z(self, weights, a, clf):\n",
    "        return sum([weights[i]*np.exp(-1*a*self.Y[i]*clf[i]) for i in range(self.M)])\n",
    "        \n",
    "    # 权值更新\n",
    "    def _w(self, a, clf, Z):\n",
    "        for i in range(self.M):\n",
    "            self.weights[i] = self.weights[i]*np.exp(-1*a*self.Y[i]*clf[i])/ Z\n",
    "    \n",
    "    # G(x)的线性组合\n",
    "    def _f(self, alpha, clf_sets):\n",
    "        pass\n",
    "    \n",
    "    def G(self, x, v, direct):\n",
    "        if direct == 'positive':\n",
    "            return 1 if x > v else -1 \n",
    "        else:\n",
    "            return -1 if x > v else 1 \n",
    "    \n",
    "    def fit(self, X, y):\n",
    "        self.init_args(X, y)\n",
    "        \n",
    "        for epoch in range(self.clf_num):\n",
    "            best_clf_error, best_v, clf_result = 100000, None, None\n",
    "            # 根据特征维度, 选择误差最小的\n",
    "            for j in range(self.N):\n",
    "                features = self.X[:, j]\n",
    "                # 分类阈值，分类误差，分类结果\n",
    "                v, direct, error, compare_array = self._G(features, self.Y, self.weights)\n",
    "                \n",
    "                if error < best_clf_error:\n",
    "                    best_clf_error = error\n",
    "                    best_v = v\n",
    "                    final_direct = direct\n",
    "                    clf_result = compare_array\n",
    "                    axis = j\n",
    "                    \n",
    "                # print('epoch:{}/{} feature:{} error:{} v:{}'.format(epoch, self.clf_num, j, error, best_v))\n",
    "                if best_clf_error == 0:\n",
    "                    break\n",
    "                \n",
    "            # 计算G(x)系数a\n",
    "            a = self._alpha(best_clf_error)\n",
    "            self.alpha.append(a)\n",
    "            # 记录分类器\n",
    "            self.clf_sets.append((axis, best_v, final_direct))\n",
    "            # 规范化因子\n",
    "            Z = self._Z(self.weights, a, clf_result)\n",
    "            # 权值更新\n",
    "            self._w(a, clf_result, Z)\n",
    "            \n",
    "#             print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, error, best_v, final_direct, a))\n",
    "#             print('weight:{}'.format(self.weights))\n",
    "#             print('\\n')\n",
    "    \n",
    "    def predict(self, feature):\n",
    "        result = 0.0\n",
    "        for i in range(len(self.clf_sets)):\n",
    "            axis, clf_v, direct = self.clf_sets[i]\n",
    "            f_input = feature[axis]\n",
    "            result += self.alpha[i] * self.G(f_input, clf_v, direct)\n",
    "        # sign\n",
    "        return 1 if result > 0 else -1\n",
    "    \n",
    "    def score(self, X_test, y_test):\n",
    "        right_count = 0\n",
    "        for i in range(len(X_test)):\n",
    "            feature = X_test[i]\n",
    "            if self.predict(feature) == y_test[i]:\n",
    "                right_count += 1\n",
    "        \n",
    "        return right_count / len(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 例8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X = np.arange(10).reshape(10, 1)\n",
    "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "classifier:1/3 error:0.300 v:2.5 direct:nagetive a:0.42365\n",
      "weight:[0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.16666666666666663, 0.16666666666666663, 0.16666666666666663, 0.071428571428571425]\n",
      "\n",
      "\n",
      "classifier:2/3 error:0.214 v:8.5 direct:nagetive a:0.64964\n",
      "weight:[0.045454545454545463, 0.045454545454545463, 0.045454545454545463, 0.16666666666666669, 0.16666666666666669, 0.16666666666666669, 0.10606060606060606, 0.10606060606060606, 0.10606060606060606, 0.045454545454545463]\n",
      "\n",
      "\n",
      "classifier:3/3 error:0.182 v:5.5 direct:nagetive a:0.75204\n",
      "weight:[0.12499999999999996, 0.12499999999999996, 0.12499999999999996, 0.10185185185185185, 0.10185185185185185, 0.10185185185185185, 0.064814814814814797, 0.064814814814814797, 0.064814814814814797, 0.12499999999999996]\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "clf = AdaBoost(n_estimators=3, learning_rate=0.5)\n",
    "clf.fit(X, y)"
   ]
  },
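  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the first round, assuming (from `v:2.5` and `error:0.300` in the output above) that the first stump predicts $+1$ for $x<2.5$ and $-1$ for $x>2.5$, and using the normalized update $w_{2,i}=w_{1,i}\\,e^{-\\alpha_1 y_i G_1(x_i)}/Z_1$ that `_Z` and `_w` implement. Under that assumption the stump misclassifies $x=6,7,8$, each with initial weight $0.1$:\n",
    "\n",
    "$$e_1=3\\times 0.1=0.3,\\qquad \\alpha_1=\\tfrac{1}{2}\\ln\\frac{1-0.3}{0.3}\\approx 0.4236$$\n",
    "\n",
    "$$Z_1=7\\times 0.1\\,e^{-\\alpha_1}+3\\times 0.1\\,e^{\\alpha_1}=2\\sqrt{0.3\\times 0.7}\\approx 0.9165$$\n",
    "\n",
    "$$w_{2,i}=\\frac{0.1\\,e^{-\\alpha_1}}{Z_1}\\approx 0.0714\\ \\text{(correctly classified)},\\qquad w_{2,i}=\\frac{0.1\\,e^{\\alpha_1}}{Z_1}\\approx 0.1667\\ \\text{(}x=6,7,8\\text{)}$$\n",
    "\n",
    "These values agree with the printed `a:0.42365` and with the first `weight` list, and the three misclassified samples now hold $3\\times 0.1667=0.5$ of the total weight."
   ]
  },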
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8484848484848485"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n",
    "clf.fit(X_train, y_train)\n",
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average score:63.061%\n"
     ]
    }
   ],
   "source": [
    "# 100次结果\n",
    "result = []\n",
    "for i in range(1, 101):\n",
    "    X, y = create_data()\n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)\n",
    "    clf = AdaBoost(n_estimators=100, learning_rate=0.2)\n",
    "    clf.fit(X_train, y_train)\n",
    "    r = clf.score(X_test, y_test)\n",
    "    # print('{}/100 score：{}'.format(i, r))\n",
    "    result.append(r)\n",
    "\n",
    "print('average score:{:.3f}%'.format(sum(result)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----\n",
    "# sklearn.ensemble.AdaBoostClassifier\n",
    "\n",
    "- algorithm：这个参数只有AdaBoostClassifier有。主要原因是scikit-learn实现了两种Adaboost分类算法，SAMME和SAMME.R。两者的主要区别是弱学习器权重的度量，SAMME使用了和我们的原理篇里二元分类Adaboost算法的扩展，即用对样本集分类效果作为弱学习器权重，而SAMME.R使用了对样本集分类的预测概率大小来作为弱学习器权重。由于SAMME.R使用了概率度量的连续值，迭代一般比SAMME快，因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。我们一般使用默认的SAMME.R就够了，但是要注意的是使用了SAMME.R， 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。\n",
    "\n",
    "- n_estimators： AdaBoostClassifier和AdaBoostRegressor都有，就是我们的弱学习器的最大迭代次数，或者说最大的弱学习器的个数。一般来说n_estimators太小，容易欠拟合，n_estimators太大，又容易过拟合，一般选择一个适中的数值。默认是50。在实际调参的过程中，我们常常将n_estimators和下面介绍的参数learning_rate一起考虑。\n",
    "\n",
    "-  learning_rate:  AdaBoostClassifier和AdaBoostRegressor都有，即每个弱学习器的权重缩减系数ν\n",
    "\n",
    "- base_estimator：AdaBoostClassifier和AdaBoostRegressor都有，即我们的弱分类学习器或者弱回归学习器。理论上可以选择任何一个分类或者回归学习器，不过需要支持样本权重。我们常用的一般是CART决策树或者神经网络MLP。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n",
       "          learning_rate=0.5, n_estimators=100, random_state=None)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import AdaBoostClassifier\n",
    "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.90909090909090906"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
